Deformable-DETR详解与代码解读

DETR是第一个end2end的目标检测器,不需要众多手工设计组件(anchor,iou匹配,nms后处理等),但也存在收敛慢,能处理的特征分辨率有限等缺陷。原因大概存在如下:

  • transformer在初始化时,分配给所有特征像素的注意力权重几乎均等;这就造成了模型需要长时间去学习关注真正有意义的位置,这些位置应该是稀疏的;
  • transformer在计算注意力权重时,伴随着高计算量与空间复杂度。特别是在编码器部分,与特征像素点的数量成平方级关系,因此难以处理高分辨率的特征;

Deformable DETR的工作就在于解决DETR收敛慢以及高计算复杂度问题。具体做法有:

多尺度特征 & 多尺度Embedding

在DETR中,由于计算复杂度的问题,仅仅只使用了单尺度特征,对于特征点位置信息编码使用三角函数,不同位置对应不同编码值。而在多尺度特征中,位于不同特征层的特征点可能拥有相同的(h,w)坐标,使用一套位置编码是无法区分。因此作者使用多尺度特征时,增加了scale-level embedding,用于区分不同特征层。不同于三角函数使用固定公式计算编码,scale-level embedding是可学习的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
class DeformableTransformer(nn.Module):
def __init__(self, d_model=256, nhead=8,
num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=1024, dropout=0.1,
activation="relu", return_intermediate_dec=False,
num_feature_levels=4, dec_n_points=4, enc_n_points=4,
two_stage=False, two_stage_num_proposals=300):
super().__init__()
...
# scale level embedding,对4个特征层分别附加d_model维度的embedding,用于区分query对应的具体特征层
self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model))
...
def forward(self, srcs, masks, pos_embeds, query_embed=None):
...
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
...
# (bs,c,h,w) --> (bs,h*w,c)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
#与position embedding相加,且同一特征层,所有query的scale-level embedding是相同的
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
...

Deformable Attention

通俗地来讲,可变性注意力即,query不是和全局每个位置的key都计算注意力权重,而仅在全局位置中采样部分位置的key,并且value也是基于这些位置进行采样插值得到,最后将局部&稀疏的注意力权重施加在对应的value上。

如上图所示,每个query在每个head上采样K个位置,只需和这些位置的特征进行交互,不同于detr那样,每个query需要与全局位置进行交互。需要注意的是,位置偏移量$\Delta p_{mqx}$ 是由query经过全连接得到的,注意力权重也是由query经全连接层得到的,同时在K个采样点之间进行权重归一化。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class MSDeformAttn(nn.Module):
def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
...
self.n_points = n_points
# 采样点的坐标偏移量,每个query在每个head,level上都需要采样n_points个点(x,y)。
self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
# 每个query对应的所有采样点的注意力权重
self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
# 线性变换得到value
self.value_proj = nn.Linear(d_model, d_model)
self.output_proj = nn.Linear(d_model, d_model)
...

def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):

...
# (N,len_in, d_model=256)
value = self.value_proj(input_flatten)
# 将原图padding的部分用0填充
if input_padding_mask is not None:
value = value.masked_fill(input_padding_mask[..., None], float(0))
# 拆分成多个head,(N,Len_in,8,64)
value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
# 预测采样点偏移量,(N,Len_in,8,4,4,2)
sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
# 预测采样点注意力权重,(N,Len_in,8,4*4)
attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
# 权重归一化,对4个特征层分别采样的4个特征点,合计16个点,进行归一化
attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
# N, Len_q, n_heads, n_levels, n_points, 2

if reference_points.shape[-1] == 2:
# (4,2) 其中每个是w,h
offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
# 对坐标偏移量使用对应特征层的宽高进行归一化然后和参考点坐标相加得到采样点坐标
sampling_locations = reference_points[:, :, None, :, None, :] \
+ sampling_offsets / offset_normalizer[None, None, None, :, None, :]
elif reference_points.shape[-1] == 4:
# 最后一维度是4表示(cx,cy,w,h)
sampling_locations = reference_points[:, :, None, :, None, :2] \
+ sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
else:
raise ValueError(
'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
# 将注意力权重与value进行计算,是调用的self.im2col_step函数
output = MSDeformAttnFunction.apply(
value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
# 做线性变换得到最终输出结果
output = self.output_proj(output)
return output

Deformable Transformer

与DETR大体一致,主要区别在于用Deformable Attention替换了Encoder中的self-attn和Decoder中的cross-attn。

Encoder前处理
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
class DeformableTransformer(nn.Module):

...

def forward(self, srcs, masks, pos_embeds, query_embed=None):
assert self.two_stage or query_embed is not None

# prepare input for encoder
src_flatten = []
mask_flatten = []
lvl_pos_embed_flatten = []
spatial_shapes = []
for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
src = src.flatten(2).transpose(1, 2)
mask = mask.flatten(1)
pos_embed = pos_embed.flatten(2).transpose(1, 2)
lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1)
lvl_pos_embed_flatten.append(lvl_pos_embed)
src_flatten.append(src)
mask_flatten.append(mask)

#多尺度特征flatten后进行拼接
src_flatten = torch.cat(src_flatten, 1)
#多尺度mask图flatten后进行拼接
mask_flatten = torch.cat(mask_flatten, 1)
# 多尺度位置信息flatten后进行拼接
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1)
# 记录每个特征图的尺度信息
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device)
# 记录每个尺度特征图拼接后的起始索引
level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
# 计算各个尺度特征图中非padding部分的边长占其边长的比例
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)

# encoder
memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)
...

在DeformableTransformer的前向处理过程中,首先会对多尺度相关的元素进行flatten,这些输入元素包括:多尺度特征图、各尺度特征图对应的mask(指示哪些部分属于padding)、各尺度特征图对应的位置信息(position embedding + scale-level embedding),另外还有些辅助信息,比如:各尺度特征图的宽高、不同尺度特征对应于被flatten的那个维度的起始索引、各尺度特征图中非padding部分的边长占其边长的比例。

Encoder编码

经过Encoder前处理之后的信息就会经过Encoder进行编码,输出memory。下面代码展示的是Encoder的处理过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class DeformableTransformerEncoder(nn.Module):
def __init__(self, encoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(encoder_layer, num_layers)
self.num_layers = num_layers

@staticmethod
def get_reference_points(spatial_shapes, valid_ratios, device):
reference_points_list = []
for lvl, (H_, W_) in enumerate(spatial_shapes):

ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))
ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_)
ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_)
ref = torch.stack((ref_x, ref_y), -1)
reference_points_list.append(ref)
reference_points = torch.cat(reference_points_list, 1)
reference_points = reference_points[:, :, None] * valid_ratios[:, None]
return reference_points

def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None):
output = src
# 参考点初始化,以0.5为步长,在特征图上密集采样所有点作为初始参考点
reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device)
for _, layer in enumerate(self.layers):
output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask)

return output

输出memory(编码后的特征表示),shape是 (bs, h_lvl1w_lvl1+h_lvl2w_lvl2+.., c=256),其中h_lvli和w_lvli分别代表第i层特征图的高和宽,于是第二个维度就是所有特征点的数量。编码后,特征的最后一个维度hidden_dim=256.

Decoder前处理

对encoder的输出进行处理,得到参考点reference_points,需要说明下,在2-stage模式下,参考点和输入到Decoder的object query及query embedding的生成方式和形式会有所不同。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class DeformableTransformer(nn.Module):
...
def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
N_, S_, C_ = memory.shape
base_scale = 4.0
proposals = []
_cur = 0
for lvl, (H_, W_) in enumerate(spatial_shapes):
mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1)
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device))
grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)

scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2)
grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale
wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl)
proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
proposals.append(proposal)
_cur += (H_ * W_)
output_proposals = torch.cat(proposals, 1)
output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
output_proposals = torch.log(output_proposals / (1 - output_proposals))
output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf'))
output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf'))

output_memory = memory
output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
output_memory = self.enc_output_norm(self.enc_output(output_memory))
return output_memory, output_proposals

def forward(self, srcs, masks, pos_embeds, query_embed=None):
...
# encoder
memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten)

# prepare input for decoder
bs, _, c = memory.shape
if self.two_stage:
# 生成proposal,对encoder的输出进行处理(全连接层+归一化),output_proposals对应于特征图上各个初始参考点位置(固定的)
output_memory, output_proposals = self.gen_encoder_output_proposals(memory, mask_flatten, spatial_shapes)


# hack implementation for two-stage Deformable DETR
# 借助于decoder的最后一层的class_embed和bbox_embed获取分数和proposal
enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
# bbox_embed预测的是相对于初始参考点位置的偏移量,所以需要加上初始参考点位置
enc_outputs_coord_unact = self.decoder.bbox_embed[self.decoder.num_layers](output_memory) + output_proposals

topk = self.two_stage_num_proposals
# 选取分数topk的proposal
topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
topk_coords_unact = torch.gather(enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4))
topk_coords_unact = topk_coords_unact.detach()
reference_points = topk_coords_unact.sigmoid()
init_reference_out = reference_points
pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact)))
query_embed, tgt = torch.split(pos_trans_out, c, dim=2)
else:
query_embed, tgt = torch.split(query_embed, c, dim=1)
query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1)
tgt = tgt.unsqueeze(0).expand(bs, -1, -1)
reference_points = self.reference_points(query_embed).sigmoid()
init_reference_out = reference_points

# decoder
hs, inter_references = self.decoder(tgt, reference_points, memory,
spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten)

inter_references_out = inter_references
if self.two_stage:
return hs, init_reference_out, inter_references_out, enc_outputs_class, enc_outputs_coord_unact
return hs, init_reference_out, inter_references_out, None, None
  • 如果是two-stage 模式下,参考点由Encoder预测topk得分最高的proposal box(这时的参考点是4d,bbox形式),然后对参考点进行position embedding来生成Decoder需要的object query和对应的query embedding;
  • 非two-stage模式下,Decoder的 object query(target )和 query embedding 就是预设的embedding,然后将query embedding经过全连接层输出2d参考点,这时的参考点是归一化的中心坐标形式。

另外,两种情况下生成的参考点数量可能不同:2-stage时是有top-k(作者设置为300)个,而1-stage时是num_queries(作者也设置为300)个,也就是和object query的数量一致(可以理解为,此时参考点就是object query本身的位置)。

Decoder解码

这里与Transformer中主要的区别在于使用可变形注意力替代了原生的交叉注意力。类似地,每层的解码过程是self-attention+cross-attention+ffn,下一层输入的object query是上一层输出的解码特征。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
class DeformableTransformerDecoder(nn.Module):
...
def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios,
query_pos=None, src_padding_mask=None):
output = tgt

intermediate = []
intermediate_reference_points = []

for lid, layer in enumerate(self.layers):
if reference_points.shape[-1] == 4:
reference_points_input = reference_points[:, :, None] \
* torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None]
else:
assert reference_points.shape[-1] == 2
reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None]
output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask)

# hack implementation for iterative bounding box refinement
# 使用iterative bbox refinement
if self.bbox_embed is not None:
# 得到相对初始参考点的偏移量
tmp = self.bbox_embed[lid](output)
# 得到归一化坐标点
if reference_points.shape[-1] == 4:
new_reference_points = tmp + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
else:
assert reference_points.shape[-1] == 2
new_reference_points = tmp
new_reference_points[..., :2] = tmp[..., :2] + inverse_sigmoid(reference_points)
new_reference_points = new_reference_points.sigmoid()
# 在输入下一层之前取消梯度,将当前层的预测bbox作为下一层的初始参考点
reference_points = new_reference_points.detach()

if self.return_intermediate:
intermediate.append(output)
intermediate_reference_points.append(reference_points)

if self.return_intermediate:
return torch.stack(intermediate), torch.stack(intermediate_reference_points)

return output, reference_points

进行Decoder解码完之后就是接class_embed和bbox_embed得到最后box分数和坐标。在上面需要注意的一点是,每次refine后的bbox梯度是不会传递到下一层

总结

  1. 相比于DETR,使用了多尺度特征+scale-level embedding,用于区分不同特征层;
  2. 使用了多尺度可变形注意力替代Encoder中的selfattn和Decoder中的crossattn,减小计算量;
  3. 引入了参考点,类似引入先验知识;
  4. 设计了两阶段模式和iteraitive box refinement策略;
  5. 检测头回归分支预测是bbox相对参考点的偏移量而非绝对坐标值;